import gc
import logging
import platform
import re
import threading
import time
from collections import Counter, defaultdict
from contextlib import contextmanager
from dataclasses import fields
from typing import Dict, List, Optional
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

import ray
from ray._common.test_utils import wait_for_condition
from ray._private.test_utils import run_string_as_driver
from ray.data._internal.block_batching.iter_batches import BatchIterator
from ray.data._internal.execution.backpressure_policy import (
    ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY,
)
from ray.data._internal.execution.backpressure_policy.backpressure_policy import (
    BackpressurePolicy,
)
from ray.data._internal.execution.dataset_state import DatasetState
from ray.data._internal.execution.interfaces.common import RuntimeMetricsHistogram
from ray.data._internal.execution.interfaces.op_runtime_metrics import (
    TaskDurationStats,
)
from ray.data._internal.execution.interfaces.physical_operator import PhysicalOperator
from ray.data._internal.execution.streaming_executor import StreamingExecutor
from ray.data._internal.stats import (
    DatasetStats,
    NodeMetrics,
    OperatorStatsSummary,
    _StatsActor,
    get_or_create_stats_actor,
)
from ray.data._internal.util import MemoryProfiler
from ray.data.block import BlockExecStats, BlockStats
from ray.data.context import DataContext
from ray.data.tests.util import column_udf
from ray.tests.conftest import *  # noqa


@pytest.mark.skipif(
    platform.system() != "Linux", reason="MemoryProfiler only supported on Linux"
)
def test_block_exec_stats_max_uss_bytes_with_polling(ray_start_regular_shared):
    """With a polling interval configured, the profiler should have sampled
    the peak USS while the large allocation was alive, so the estimate
    remains above the allocation size even after it is freed."""
    alloc_nbytes = 1024**3  # 1 GiB
    poll_interval_s = 0.01
    with MemoryProfiler(poll_interval_s=poll_interval_s) as profiler:
        buf = np.random.randint(0, 256, size=(alloc_nbytes,), dtype=np.uint8)
        # Give the background poller time to observe the live allocation.
        time.sleep(poll_interval_s * 2)

        # Free the buffer; the sampled peak should still reflect it.
        del buf
        gc.collect()

        assert profiler.estimate_max_uss() > alloc_nbytes


@pytest.mark.skipif(
    platform.system() != "Linux", reason="MemoryProfiler only supported on Linux"
)
def test_block_exec_stats_max_uss_bytes_without_polling(ray_start_regular_shared):
    """Without background polling, the profiler must still estimate a peak
    USS that covers the allocation while it is alive."""
    alloc_nbytes = 1024**3  # 1 GiB
    with MemoryProfiler(poll_interval_s=None) as profiler:
        _ = np.random.randint(0, 256, size=(alloc_nbytes,), dtype=np.uint8)

        assert profiler.estimate_max_uss() > alloc_nbytes


def gen_expected_metrics(
    is_map: bool,
    spilled: bool = False,
    task_backpressure: bool = False,
    task_output_backpressure: bool = False,
    extra_metrics: Optional[List[str]] = None,
) -> str:
    """Build the canonicalized "extra metrics" dict string expected in stats.

    Placeholder letters mirror what ``canonicalize`` produces: ``N`` for a
    positive number, ``Z`` for zero, ``A`` for a value that may be zero or
    positive (``obj_store_mem_used``), and ``H`` for platform-dependent
    memory estimates that may be missing.

    Args:
        is_map: If True, generate the metrics for a map operator (nonzero
            task counters and per-task averages); otherwise generate the
            non-map variant (``None`` averages, zero task counters).
        spilled: Whether ``obj_store_mem_spilled`` is expected to be nonzero.
        task_backpressure: Whether ``task_submission_backpressure_time`` is
            expected to be nonzero.
        task_output_backpressure: Whether ``task_output_backpressure_time``
            is expected to be nonzero.
        extra_metrics: Additional entries appended verbatim to the dict body.

    Returns:
        The expected metrics rendered as a dict-literal string.
    """
    if is_map:
        metrics = [
            "'average_num_outputs_per_task': N",
            "'average_num_inputs_per_task': N",
            "'num_output_blocks_per_task_s': N",
            "'average_total_task_completion_time_s': N",
            "'average_task_completion_excl_backpressure_time_s': N",
            "'average_bytes_per_output': N",
            "'obj_store_mem_internal_inqueue': Z",
            "'obj_store_mem_internal_outqueue': Z",
            "'obj_store_mem_pending_task_inputs': Z",
            "'average_bytes_inputs_per_task': N",
            "'average_rows_inputs_per_task': N",
            "'average_bytes_outputs_per_task': N",
            "'average_rows_outputs_per_task': N",
            "'average_max_uss_per_task': H",
            "'num_inputs_received': N",
            "'num_row_inputs_received': N",
            "'bytes_inputs_received': N",
            "'num_task_inputs_processed': N",
            "'bytes_task_inputs_processed': N",
            "'bytes_inputs_of_submitted_tasks': N",
            "'rows_inputs_of_submitted_tasks': N",
            "'num_task_outputs_generated': N",
            "'bytes_task_outputs_generated': N",
            "'rows_task_outputs_generated': N",
            "'row_outputs_taken': N",
            "'block_outputs_taken': N",
            "'num_outputs_taken': N",
            "'bytes_outputs_taken': N",
            "'num_outputs_of_finished_tasks': N",
            "'bytes_outputs_of_finished_tasks': N",
            "'rows_outputs_of_finished_tasks': N",
            "'num_external_inqueue_blocks': Z",
            "'num_external_inqueue_bytes': Z",
            "'num_external_outqueue_blocks': Z",
            "'num_external_outqueue_bytes': Z",
            "'num_tasks_submitted': N",
            "'num_tasks_running': Z",
            "'num_tasks_have_outputs': N",
            "'num_tasks_finished': N",
            "'num_tasks_failed': Z",
            "'block_generation_time': N",
            (
                "'task_submission_backpressure_time': "
                f"{'N' if task_backpressure else 'Z'}"
            ),
            (
                "'task_output_backpressure_time': "
                f"{'N' if task_output_backpressure else 'Z'}"
            ),
            "'task_completion_time_total_s': N",
            "'task_completion_time': (samples: N, avg: N)",
            "'block_completion_time': (samples: N, avg: N)",
            "'task_completion_time_excl_backpressure_s': N",
            "'block_size_bytes': (samples: N, avg: N)",
            "'block_size_rows': (samples: N, avg: N)",
            "'num_alive_actors': Z",
            "'num_restarting_actors': Z",
            "'num_pending_actors': Z",
            "'obj_store_mem_internal_inqueue_blocks': Z",
            "'obj_store_mem_internal_outqueue_blocks': Z",
            "'obj_store_mem_freed': N",
            f"""'obj_store_mem_spilled': {"N" if spilled else "Z"}""",
            "'obj_store_mem_used': A",
            "'cpu_usage': Z",
            "'gpu_usage': Z",
        ]
    else:
        metrics = [
            "'average_num_outputs_per_task': None",
            "'average_num_inputs_per_task': None",
            "'num_output_blocks_per_task_s': None",
            "'average_total_task_completion_time_s': None",
            "'average_task_completion_excl_backpressure_time_s': None",
            "'average_bytes_per_output': None",
            "'obj_store_mem_internal_inqueue': Z",
            "'obj_store_mem_internal_outqueue': Z",
            "'obj_store_mem_pending_task_inputs': Z",
            "'average_bytes_inputs_per_task': None",
            "'average_rows_inputs_per_task': None",
            "'average_bytes_outputs_per_task': None",
            "'average_rows_outputs_per_task': None",
            "'average_max_uss_per_task': H",
            "'num_inputs_received': N",
            "'num_row_inputs_received': N",
            "'bytes_inputs_received': N",
            "'num_task_inputs_processed': Z",
            "'bytes_task_inputs_processed': Z",
            "'bytes_inputs_of_submitted_tasks': Z",
            "'rows_inputs_of_submitted_tasks': Z",
            "'num_task_outputs_generated': Z",
            "'bytes_task_outputs_generated': Z",
            "'rows_task_outputs_generated': Z",
            "'row_outputs_taken': N",
            "'block_outputs_taken': N",
            "'num_outputs_taken': N",
            "'bytes_outputs_taken': N",
            "'num_outputs_of_finished_tasks': Z",
            "'bytes_outputs_of_finished_tasks': Z",
            "'rows_outputs_of_finished_tasks': Z",
            "'num_external_inqueue_blocks': Z",
            "'num_external_inqueue_bytes': Z",
            "'num_external_outqueue_blocks': Z",
            "'num_external_outqueue_bytes': Z",
            "'num_tasks_submitted': Z",
            "'num_tasks_running': Z",
            "'num_tasks_have_outputs': Z",
            "'num_tasks_finished': Z",
            "'num_tasks_failed': Z",
            "'block_generation_time': Z",
            (
                "'task_submission_backpressure_time': "
                f"{'N' if task_backpressure else 'Z'}"
            ),
            (
                "'task_output_backpressure_time': "
                f"{'N' if task_output_backpressure else 'Z'}"
            ),
            "'task_completion_time_total_s': Z",
            "'task_completion_time': (samples: Z, avg: Z)",
            "'block_completion_time': (samples: Z, avg: Z)",
            (
                "'task_completion_time_excl_backpressure_s': "
                f"{'N' if task_backpressure else 'Z'}"
            ),
            "'block_size_bytes': (samples: Z, avg: Z)",
            "'block_size_rows': (samples: Z, avg: Z)",
            "'num_alive_actors': Z",
            "'num_restarting_actors': Z",
            "'num_pending_actors': Z",
            "'obj_store_mem_internal_inqueue_blocks': Z",
            "'obj_store_mem_internal_outqueue_blocks': Z",
            "'obj_store_mem_freed': Z",
            "'obj_store_mem_spilled': Z",
            "'obj_store_mem_used': A",
            "'cpu_usage': Z",
            "'gpu_usage': Z",
        ]
    if extra_metrics:
        metrics.extend(extra_metrics)
    return "{" + ", ".join(metrics) + "}"


def gen_extra_metrics_str(metrics: str, verbose: bool) -> str:
    """Return the "Extra metrics" report line (with trailing newline) when
    verbose stats are enabled, otherwise an empty string."""
    if not verbose:
        return ""
    return f"* Extra metrics: {metrics}" + "\n"


def gen_runtime_metrics_str(op_names: List[str], verbose: bool) -> str:
    """Build the canonicalized "Runtime Metrics" section of a stats report.

    One ``* <name>: T (N%)`` line is emitted per operator, followed by the
    fixed "Scheduling" and "Total" rows. Returns '' when verbose stats are
    disabled.
    """
    if not verbose:
        return ""
    rows = [f"* {name}: T (N%)\n" for name in op_names + ["Scheduling", "Total"]]
    return "\nRuntime Metrics:\n" + "".join(rows)


# Pre-built expected "extra metrics" strings shared by the tests below.
# Map operator, no spilling, no task-submission backpressure.
STANDARD_EXTRA_METRICS = gen_expected_metrics(
    is_map=True,
    spilled=False,
    extra_metrics=[
        "'ray_remote_args': {'num_cpus': N, 'scheduling_strategy': 'SPREAD'}"
    ],
)

# Map operator, no spilling, with nonzero task-submission backpressure time.
STANDARD_EXTRA_METRICS_TASK_BACKPRESSURE = gen_expected_metrics(
    is_map=True,
    spilled=False,
    task_backpressure=True,
    extra_metrics=[
        "'ray_remote_args': {'num_cpus': N, 'scheduling_strategy': 'SPREAD'}"
    ],
)

# Map operator with object store memory spilled to disk.
MEM_SPILLED_EXTRA_METRICS = gen_expected_metrics(
    is_map=True,
    spilled=True,
    extra_metrics=[
        "'ray_remote_args': {'num_cpus': N, 'scheduling_strategy': 'SPREAD'}"
    ],
)

# Map operator with spilling and nonzero task-submission backpressure time.
MEM_SPILLED_EXTRA_METRICS_TASK_BACKPRESSURE = gen_expected_metrics(
    is_map=True,
    spilled=True,
    task_backpressure=True,
    extra_metrics=[
        "'ray_remote_args': {'num_cpus': N, 'scheduling_strategy': 'SPREAD'}"
    ],
)


# Canonicalized cluster-wide memory section of a stats report; stripped from
# results by ``canonicalize`` when ``filter_global_stats`` is True.
CLUSTER_MEMORY_STATS = """
Cluster memory:
* Spilled to disk: M
* Restored from disk: M
"""

# Canonicalized per-dataset memory section; also stripped by ``canonicalize``
# when ``filter_global_stats`` is True.
DATASET_MEMORY_STATS = """
Dataset memory:
* Spilled to disk: M
"""

# Canonicalized per-operator execution summary line.
EXECUTION_STRING = "N tasks executed, N blocks produced in T"


def canonicalize(
    stats: str,
    filter_global_stats: bool = True,
) -> str:
    """Normalize a stats report so it can be compared across runs.

    Run-dependent values are replaced with placeholder letters: ``U`` for
    dataset UUIDs, ``T`` for times, ``M`` for memory sizes, ``N`` for
    positive numbers, ``Z`` for zeros, ``A`` for zero-or-positive
    ``obj_store_mem_used`` values, and ``H`` for platform-dependent peak
    memory estimates that may be absent.

    Args:
        stats: The raw stats string.
        filter_global_stats: If True, also strip the cluster/dataset memory
            sections from the result.

    Returns:
        The canonicalized stats string.
    """
    # (pattern, replacement) pairs applied in order. Ordering matters: the
    # specific rules (UUIDs, times, histograms, zeros, scientific notation)
    # must run before the catch-all "other numerics" rule.
    rules = [
        # Dataset UUID expression.
        (r"([a-f\d]{32})", "U"),
        # Time expressions.
        (r"[0-9\.]+(ms|us|s)", "T"),
        # Memory expressions.
        (r"[0-9\.]+(B|MB|GB)", "M"),
        # Histogram expressions (empty, then populated).
        (r"\(samples: 0, avg: 0.00\)", "(samples: Z, avg: Z)"),
        (r"\(samples: \d+, avg: \d+\.\d+\)", "(samples: N, avg: N)"),
        # obj_store_mem_used can be zero or positive depending on the run;
        # replace the number with 'A' (keeping the key prefix) to avoid
        # flakiness.
        (r"(obj_store_mem_used: |'obj_store_mem_used': )\d+(\.\d+)?", r"\g<1>A"),
        # Floats in (0, 1).
        (r" (0\.0*[1-9][0-9]*)", " N"),
        # Input/output row totals (0 or non-0) become 'N', keeping the prefix.
        (r"(Total input num rows: )\d+(\.\d+)?", r"\g<1>N"),
        (r"(Total output num rows: )\d+(\.\d+)?", r"\g<1>N"),
        # Zeros map to a distinct letter so missing values can be detected.
        (r" [0]+(\.[0])?", " Z"),
        # Scientific notation for small or large numbers.
        (r"\d+(\.\d+)?[eE][-+]?\d+", "N"),
        # Any remaining numerics.
        (r"[0-9]+(\.[0-9]+)?", "N"),
        # Tabs become four spaces.
        ("\t", "    "),
        # Peak heap memory estimates may be 0 or None on some platforms, so
        # min/max/mean all collapse to 'H'.
        (
            r"Peak heap memory usage \(MiB\): (N|Z) min, (N|Z) max, (N|Z) mean",
            "Peak heap memory usage (MiB): H min, H max, H mean",
        ),
        (
            r"(average_max_uss_per_task:|'average_max_uss_per_task':) (?:N|Z|None)\b",
            r"\g<1> H",
        ),
    ]

    canonical = stats
    for pattern, replacement in rules:
        canonical = re.sub(pattern, replacement, canonical)

    if filter_global_stats:
        canonical = canonical.replace(CLUSTER_MEMORY_STATS, "")
        canonical = canonical.replace(DATASET_MEMORY_STATS, "")
    return canonical


def dummy_map_batches(x):
    """Identity UDF for map_batches calls: returns its input unchanged."""
    return x


def dummy_map_batches_sleep(n):
    """Build an identity UDF for map_batches that sleeps ``n`` seconds
    before returning its input."""

    def _sleep_then_identity(batch):
        time.sleep(n)
        return batch

    return _sleep_then_identity


@contextmanager
def patch_update_stats_actor():
    """Patch ``_StatsManager.update_execution_metrics``, yielding the mock
    so callers can assert on the recorded calls."""
    patcher = patch(
        "ray.data._internal.stats._StatsManager.update_execution_metrics"
    )
    mock_update = patcher.start()
    try:
        yield mock_update
    finally:
        patcher.stop()


@contextmanager
def patch_update_stats_actor_iter():
    """Patch ``_StatsManager.update_iteration_metrics``, yielding the mock
    so callers can assert on the recorded calls."""
    patcher = patch(
        "ray.data._internal.stats._StatsManager.update_iteration_metrics"
    )
    mock_update = patcher.start()
    try:
        yield mock_update
    finally:
        patcher.stop()


def test_streaming_split_stats(ray_start_regular_shared, restore_data_context):
    """Check the full stats report for a ``streaming_split`` iterator: the
    map operator section, the split operator section, the iterator time
    breakdown, and (since verbose mode is on) the runtime metrics."""
    context = DataContext.get_current()
    # Verbose mode adds the per-operator extra metrics and runtime metrics.
    context.verbose_stats_logs = True
    ds = ray.data.range(1000, override_num_blocks=10)
    it = ds.map_batches(dummy_map_batches).streaming_split(1)[0]
    # Drain the iterator so all stats are populated before reading them.
    list(it.iter_batches())
    stats = it.stats()
    extra_metrics_1 = STANDARD_EXTRA_METRICS_TASK_BACKPRESSURE  # .replace(
    #     "'obj_store_mem_used': A", "'obj_store_mem_used': Z"
    # )
    # Expected metrics for the split operator (a non-map operator).
    extra_metrics_2 = gen_expected_metrics(
        is_map=False,
        extra_metrics=["'num_output_N': N", "'output_splitter_overhead_time': N"],
    )
    assert (
        canonicalize(stats)
        == f"""Operator N ReadRange->MapBatches(dummy_map_batches): {EXECUTION_STRING}
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* UDF time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): H min, H max, H mean
* Output num rows per block: N min, N max, N mean, N total
* Output size bytes per block: N min, N max, N mean, N total
* Output rows per task: N min, N max, N mean, N tasks used
* Tasks per node: N min, N max, N mean; N nodes used
* Operator throughput:
    * Total input num rows: N rows
    * Total output num rows: N rows
    * Ray Data throughput: N rows/s
    * Estimated single task throughput: N rows/s
* Extra metrics: {extra_metrics_1}

Operator N split(N, equal=False): \n"""
        # Workaround to preserve trailing whitespace in the above line without
        # causing linter failures.
        f"""* Extra metrics: {extra_metrics_2}\n"""
        """
Dataset iterator time breakdown:
* Total time overall: T
    * Total time in Ray Data iterator initialization code: T
    * Total time user thread is blocked by Ray Data iter_batches: T
    * Total time spent waiting for the first batch after starting iteration: T
    * Total execution time for user thread: T
* Batch iteration time breakdown (summed across prefetch threads):
    * In get RefBundles: T min, T max, T avg, T total
    * In ray.get(): T min, T max, T avg, T total
    * In batch creation: T min, T max, T avg, T total
    * In batch formatting: T min, T max, T avg, T total
Streaming split coordinator overhead time: T
"""
        f"{gen_runtime_metrics_str(['ReadRange->MapBatches(dummy_map_batches)', 'split(N, equal=False)'], True)}"  # noqa: E501
    )


@pytest.mark.parametrize("verbose_stats_logs", [True, False])
def test_dataset_stats_basic(
    ray_start_regular_shared,
    enable_auto_log_stats,
    verbose_stats_logs,
    restore_data_context,
):
    """End-to-end check of the stats report for a two-operator pipeline
    (ReadRange->MapBatches, then Map), including the auto-logged summaries
    emitted by the streaming executor after each materialize()."""
    context = DataContext.get_current()
    context.verbose_stats_logs = verbose_stats_logs
    logger = logging.getLogger("ray.data._internal.execution.streaming_executor")

    with patch.object(logger, "info") as mock_logger:
        ds = ray.data.range(1000, override_num_blocks=10)
        ds = ds.map_batches(dummy_map_batches).materialize()

        if enable_auto_log_stats:
            # Inspect the second-to-last logged message; it carries the
            # operator stats summary for the execution that just finished.
            logger_args, logger_kwargs = mock_logger.call_args_list[-2]

            assert canonicalize(logger_args[0]) == (
                f"Operator N ReadRange->MapBatches(dummy_map_batches): "
                f"{EXECUTION_STRING}\n"
                f"* Remote wall time: T min, T max, T mean, T total\n"
                f"* Remote cpu time: T min, T max, T mean, T total\n"
                f"* UDF time: T min, T max, T mean, T total\n"
                f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
                f"* Output num rows per block: N min, N max, N mean, N total\n"
                f"* Output size bytes per block: N min, N max, N mean, N total\n"
                f"* Output rows per task: N min, N max, N mean, N tasks used\n"
                f"* Tasks per node: N min, N max, N mean; N nodes used\n"
                f"* Operator throughput:\n"
                f"    * Total input num rows: N rows\n"
                f"    * Total output num rows: N rows\n"
                f"    * Ray Data throughput: N rows/s\n"
                f"    * Estimated single task throughput: N rows/s\n"
                f"{gen_extra_metrics_str(STANDARD_EXTRA_METRICS_TASK_BACKPRESSURE, verbose_stats_logs)}"  # noqa: E501
                f"\n"
                f"Dataset throughput:\n"
                f"    * Ray Data throughput: N rows/s\n"
                f"{gen_runtime_metrics_str(['ReadRange->MapBatches(dummy_map_batches)'], verbose_stats_logs)}"  # noqa: E501
            )

        ds = ds.map(dummy_map_batches).materialize()
        if enable_auto_log_stats:
            # The new execution only logs the freshly-run Map operator.
            logger_args, logger_kwargs = mock_logger.call_args_list[-2]

            assert canonicalize(logger_args[0]) == (
                f"Operator N Map(dummy_map_batches): {EXECUTION_STRING}\n"
                f"* Remote wall time: T min, T max, T mean, T total\n"
                f"* Remote cpu time: T min, T max, T mean, T total\n"
                f"* UDF time: T min, T max, T mean, T total\n"
                f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
                f"* Output num rows per block: N min, N max, N mean, N total\n"
                f"* Output size bytes per block: N min, N max, N mean, N total\n"
                f"* Output rows per task: N min, N max, N mean, N tasks used\n"
                f"* Tasks per node: N min, N max, N mean; N nodes used\n"
                f"* Operator throughput:\n"
                f"    * Total input num rows: N rows\n"
                f"    * Total output num rows: N rows\n"
                f"    * Ray Data throughput: N rows/s\n"
                f"    * Estimated single task throughput: N rows/s\n"
                f"{gen_extra_metrics_str(STANDARD_EXTRA_METRICS_TASK_BACKPRESSURE, verbose_stats_logs)}"  # noqa: E501
                f"\n"
                f"Dataset throughput:\n"
                f"    * Ray Data throughput: N rows/s\n"
                f"{gen_runtime_metrics_str(['ReadRange->MapBatches(dummy_map_batches)','Map(dummy_map_batches)'], verbose_stats_logs)}"  # noqa: E501
            )

    # Iterate so the iterator time breakdown is populated in the stats.
    for batch in ds.iter_batches():
        pass
    stats = canonicalize(ds.materialize().stats())

    extra_metrics = gen_extra_metrics_str(
        STANDARD_EXTRA_METRICS_TASK_BACKPRESSURE,
        verbose_stats_logs,
    )

    assert stats == (
        f"Operator N ReadRange->MapBatches(dummy_map_batches): {EXECUTION_STRING}\n"
        f"* Remote wall time: T min, T max, T mean, T total\n"
        f"* Remote cpu time: T min, T max, T mean, T total\n"
        f"* UDF time: T min, T max, T mean, T total\n"
        f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
        f"* Output num rows per block: N min, N max, N mean, N total\n"
        f"* Output size bytes per block: N min, N max, N mean, N total\n"
        f"* Output rows per task: N min, N max, N mean, N tasks used\n"
        f"* Tasks per node: N min, N max, N mean; N nodes used\n"
        f"* Operator throughput:\n"
        f"    * Total input num rows: N rows\n"
        f"    * Total output num rows: N rows\n"
        f"    * Ray Data throughput: N rows/s\n"
        f"    * Estimated single task throughput: N rows/s\n"
        f"{extra_metrics}\n"
        f"Operator N Map(dummy_map_batches): {EXECUTION_STRING}\n"
        f"* Remote wall time: T min, T max, T mean, T total\n"
        f"* Remote cpu time: T min, T max, T mean, T total\n"
        f"* UDF time: T min, T max, T mean, T total\n"
        f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
        f"* Output num rows per block: N min, N max, N mean, N total\n"
        f"* Output size bytes per block: N min, N max, N mean, N total\n"
        f"* Output rows per task: N min, N max, N mean, N tasks used\n"
        f"* Tasks per node: N min, N max, N mean; N nodes used\n"
        f"* Operator throughput:\n"
        f"    * Total input num rows: N rows\n"
        f"    * Total output num rows: N rows\n"
        f"    * Ray Data throughput: N rows/s\n"
        f"    * Estimated single task throughput: N rows/s\n"
        f"{extra_metrics}\n"
        f"Dataset iterator time breakdown:\n"
        f"* Total time overall: T\n"
        f"    * Total time in Ray Data iterator initialization code: T\n"
        f"    * Total time user thread is blocked by Ray Data iter_batches: T\n"
        f"    * Total time spent waiting for the first batch after starting iteration: T\n"
        f"    * Total execution time for user thread: T\n"
        f"* Batch iteration time breakdown (summed across prefetch threads):\n"
        f"    * In get RefBundles: T min, T max, T avg, T total\n"
        f"    * In ray.get(): T min, T max, T avg, T total\n"
        f"    * In batch creation: T min, T max, T avg, T total\n"
        f"    * In batch formatting: T min, T max, T avg, T total\n"
        f"\n"
        f"Dataset throughput:\n"
        f"    * Ray Data throughput: N rows/s\n"
        f"{gen_runtime_metrics_str(['ReadRange->MapBatches(dummy_map_batches)','Map(dummy_map_batches)'], verbose_stats_logs)}"  # noqa: E501
    )


def test_block_location_nums(ray_start_regular_shared, restore_data_context):
    """When object-location metrics are enabled, the stats report should
    include a "Block locations" section with local/remote/unknown counts."""
    context = DataContext.get_current()
    # Opt in to collecting block location metrics for this test.
    context.enable_get_object_locations_for_metrics = True
    ds = ray.data.range(1000, override_num_blocks=10)
    ds = ds.map_batches(dummy_map_batches).materialize()

    # Iterate so the iterator time breakdown and block locations populate.
    for batch in ds.iter_batches():
        pass
    stats = canonicalize(ds.materialize().stats())

    assert stats == (
        f"Operator N ReadRange->MapBatches(dummy_map_batches): {EXECUTION_STRING}\n"
        f"* Remote wall time: T min, T max, T mean, T total\n"
        f"* Remote cpu time: T min, T max, T mean, T total\n"
        f"* UDF time: T min, T max, T mean, T total\n"
        f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
        f"* Output num rows per block: N min, N max, N mean, N total\n"
        f"* Output size bytes per block: N min, N max, N mean, N total\n"
        f"* Output rows per task: N min, N max, N mean, N tasks used\n"
        f"* Tasks per node: N min, N max, N mean; N nodes used\n"
        f"* Operator throughput:\n"
        f"    * Total input num rows: N rows\n"
        f"    * Total output num rows: N rows\n"
        f"    * Ray Data throughput: N rows/s\n"
        f"    * Estimated single task throughput: N rows/s\n"
        f"\n"
        f"Dataset iterator time breakdown:\n"
        f"* Total time overall: T\n"
        f"    * Total time in Ray Data iterator initialization code: T\n"
        f"    * Total time user thread is blocked by Ray Data iter_batches: T\n"
        f"    * Total time spent waiting for the first batch after starting iteration: T\n"
        f"    * Total execution time for user thread: T\n"
        f"* Batch iteration time breakdown (summed across prefetch threads):\n"
        f"    * In get RefBundles: T min, T max, T avg, T total\n"
        f"    * In ray.get(): T min, T max, T avg, T total\n"
        f"    * In batch creation: T min, T max, T avg, T total\n"
        f"    * In batch formatting: T min, T max, T avg, T total\n"
        f"Block locations:\n"
        f"    * Num blocks local: Z\n"
        f"    * Num blocks remote: Z\n"
        f"    * Num blocks unknown location: N\n"
        f"\n"
        f"Dataset throughput:\n"
        f"    * Ray Data throughput: N rows/s\n"
    )


def test_dataset__repr__(ray_start_regular_shared, restore_data_context):
    """Verify the canonicalized ``repr`` of ``DatasetStatsSummary``.

    Checks two shapes: a plain materialized ReadRange dataset, and a
    MapBatches dataset derived from it, whose summary nests the ReadRange
    summary under ``parents``.
    """
    context = DataContext.get_current()
    # NOTE(review): presumably required so block-location iter stats are
    # collected for the summary — confirm against DataContext docs.
    context.enable_get_object_locations_for_metrics = True
    n = 100
    ds = ray.data.range(n)
    assert len(ds.take_all()) == n
    ds = ds.materialize()

    # Expected canonicalized repr for the materialized ReadRange dataset.
    expected_stats = (
        "DatasetStatsSummary(\n"
        "   dataset_uuid=N,\n"
        "   base_name=ReadRange,\n"
        "   number=N,\n"
        "   extra_metrics={\n"
        "      average_num_outputs_per_task: N,\n"
        "      average_num_inputs_per_task: N,\n"
        "      num_output_blocks_per_task_s: N,\n"
        "      average_total_task_completion_time_s: N,\n"
        "      average_task_completion_excl_backpressure_time_s: N,\n"
        "      average_bytes_per_output: N,\n"
        "      obj_store_mem_internal_inqueue: Z,\n"
        "      obj_store_mem_internal_outqueue: Z,\n"
        "      obj_store_mem_pending_task_inputs: Z,\n"
        "      average_bytes_inputs_per_task: N,\n"
        "      average_rows_inputs_per_task: N,\n"
        "      average_bytes_outputs_per_task: N,\n"
        "      average_rows_outputs_per_task: N,\n"
        "      average_max_uss_per_task: H,\n"
        "      num_inputs_received: N,\n"
        "      num_row_inputs_received: N,\n"
        "      bytes_inputs_received: N,\n"
        "      num_task_inputs_processed: N,\n"
        "      bytes_task_inputs_processed: N,\n"
        "      bytes_inputs_of_submitted_tasks: N,\n"
        "      rows_inputs_of_submitted_tasks: N,\n"
        "      num_task_outputs_generated: N,\n"
        "      bytes_task_outputs_generated: N,\n"
        "      rows_task_outputs_generated: N,\n"
        "      row_outputs_taken: N,\n"
        "      block_outputs_taken: N,\n"
        "      num_outputs_taken: N,\n"
        "      bytes_outputs_taken: N,\n"
        "      num_outputs_of_finished_tasks: N,\n"
        "      bytes_outputs_of_finished_tasks: N,\n"
        "      rows_outputs_of_finished_tasks: N,\n"
        "      num_external_inqueue_blocks: Z,\n"
        "      num_external_inqueue_bytes: Z,\n"
        "      num_external_outqueue_blocks: Z,\n"
        "      num_external_outqueue_bytes: Z,\n"
        "      num_tasks_submitted: N,\n"
        "      num_tasks_running: Z,\n"
        "      num_tasks_have_outputs: N,\n"
        "      num_tasks_finished: N,\n"
        "      num_tasks_failed: Z,\n"
        "      block_generation_time: N,\n"
        "      task_submission_backpressure_time: N,\n"
        "      task_output_backpressure_time: Z,\n"
        "      task_completion_time_total_s: N,\n"
        "      task_completion_time: (samples: N, avg: N),\n"
        "      block_completion_time: (samples: N, avg: N),\n"
        "      task_completion_time_excl_backpressure_s: N,\n"
        "      block_size_bytes: (samples: N, avg: N),\n"
        "      block_size_rows: (samples: N, avg: N),\n"
        "      num_alive_actors: Z,\n"
        "      num_restarting_actors: Z,\n"
        "      num_pending_actors: Z,\n"
        "      obj_store_mem_internal_inqueue_blocks: Z,\n"
        "      obj_store_mem_internal_outqueue_blocks: Z,\n"
        "      obj_store_mem_freed: N,\n"
        "      obj_store_mem_spilled: Z,\n"
        "      obj_store_mem_used: A,\n"
        "      cpu_usage: Z,\n"
        "      gpu_usage: Z,\n"
        "      ray_remote_args: {'num_cpus': N, 'scheduling_strategy': 'SPREAD'},\n"
        "   },\n"
        "   operators_stats=[\n"
        "      OperatorStatsSummary(\n"
        "         operator_name='ReadRange',\n"
        "         is_suboperator=False,\n"
        "         time_total_s=T,\n"
        f"         block_execution_summary_str={EXECUTION_STRING}\n"
        "         wall_time={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"
        "         cpu_time={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"
        "         memory={'min': 'T', 'max': 'T', 'mean': 'T'},\n"
        "         output_num_rows={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"
        "         output_size_bytes={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"  # noqa: E501
        "         node_count={'min': 'T', 'max': 'T', 'mean': 'T', 'count': 'T'},\n"
        "      ),\n"
        "   ],\n"
        "   iter_stats=IterStatsSummary(\n"
        "      wait_time=T,\n"
        "      get_ref_bundles_time=T,\n"
        "      get_time=T,\n"
        "      iter_blocks_local=None,\n"
        "      iter_blocks_remote=None,\n"
        "      iter_unknown_location=None,\n"
        "      iter_prefetched_bytes=None,\n"
        "      next_time=T,\n"
        "      format_time=T,\n"
        "      user_time=T,\n"
        "      total_time=T,\n"
        "   ),\n"
        "   global_bytes_spilled=M,\n"
        "   global_bytes_restored=M,\n"
        "   dataset_bytes_spilled=M,\n"
        "   parents=[\n"
        "      DatasetStatsSummary(\n"
        "         dataset_uuid=unknown_uuid,\n"
        "         base_name=None,\n"
        "         number=N,\n"
        "         extra_metrics={},\n"
        "         operators_stats=[],\n"
        "         iter_stats=IterStatsSummary(\n"
        "            wait_time=T,\n"
        "            get_ref_bundles_time=T,\n"
        "            get_time=T,\n"
        "            iter_blocks_local=None,\n"
        "            iter_blocks_remote=None,\n"
        "            iter_unknown_location=None,\n"
        "            iter_prefetched_bytes=None,\n"
        "            next_time=T,\n"
        "            format_time=T,\n"
        "            user_time=T,\n"
        "            total_time=T,\n"
        "         ),\n"
        "         global_bytes_spilled=M,\n"
        "         global_bytes_restored=M,\n"
        "         dataset_bytes_spilled=M,\n"
        "         parents=[],\n"
        "      ),\n"
        "   ],\n"
        ")"
    )

    def check_stats():
        # Canonicalize replaces volatile values (uuids, timings, counts) so the
        # repr can be compared against a stable expectation.
        stats = canonicalize(repr(ds._plan.stats().to_summary()))
        assert stats == expected_stats, stats
        return True

    # TODO(hchen): The reason why `wait_for_condition` is needed here is because
    # `to_summary` depends on an external actor (_StatsActor) that records stats
    # asynchronously. This makes the behavior non-deterministic.
    # See the TODO in `to_summary`.
    # We should make it deterministic and refine this test.
    wait_for_condition(
        check_stats,
        timeout=10,
        retry_interval_ms=1000,
    )

    # Chain a MapBatches on top; its summary should nest the ReadRange
    # summary above under `parents`.
    ds2 = ds.map_batches(lambda x: x).materialize()
    assert len(ds2.take_all()) == n
    expected_stats2 = (
        "DatasetStatsSummary(\n"
        "   dataset_uuid=N,\n"
        "   base_name=MapBatches(<lambda>),\n"
        "   number=N,\n"
        "   extra_metrics={\n"
        "      average_num_outputs_per_task: N,\n"
        "      average_num_inputs_per_task: N,\n"
        "      num_output_blocks_per_task_s: N,\n"
        "      average_total_task_completion_time_s: N,\n"
        "      average_task_completion_excl_backpressure_time_s: N,\n"
        "      average_bytes_per_output: N,\n"
        "      obj_store_mem_internal_inqueue: Z,\n"
        "      obj_store_mem_internal_outqueue: Z,\n"
        "      obj_store_mem_pending_task_inputs: Z,\n"
        "      average_bytes_inputs_per_task: N,\n"
        "      average_rows_inputs_per_task: N,\n"
        "      average_bytes_outputs_per_task: N,\n"
        "      average_rows_outputs_per_task: N,\n"
        "      average_max_uss_per_task: H,\n"
        "      num_inputs_received: N,\n"
        "      num_row_inputs_received: N,\n"
        "      bytes_inputs_received: N,\n"
        "      num_task_inputs_processed: N,\n"
        "      bytes_task_inputs_processed: N,\n"
        "      bytes_inputs_of_submitted_tasks: N,\n"
        "      rows_inputs_of_submitted_tasks: N,\n"
        "      num_task_outputs_generated: N,\n"
        "      bytes_task_outputs_generated: N,\n"
        "      rows_task_outputs_generated: N,\n"
        "      row_outputs_taken: N,\n"
        "      block_outputs_taken: N,\n"
        "      num_outputs_taken: N,\n"
        "      bytes_outputs_taken: N,\n"
        "      num_outputs_of_finished_tasks: N,\n"
        "      bytes_outputs_of_finished_tasks: N,\n"
        "      rows_outputs_of_finished_tasks: N,\n"
        "      num_external_inqueue_blocks: Z,\n"
        "      num_external_inqueue_bytes: Z,\n"
        "      num_external_outqueue_blocks: Z,\n"
        "      num_external_outqueue_bytes: Z,\n"
        "      num_tasks_submitted: N,\n"
        "      num_tasks_running: Z,\n"
        "      num_tasks_have_outputs: N,\n"
        "      num_tasks_finished: N,\n"
        "      num_tasks_failed: Z,\n"
        "      block_generation_time: N,\n"
        "      task_submission_backpressure_time: N,\n"
        "      task_output_backpressure_time: Z,\n"
        "      task_completion_time_total_s: N,\n"
        "      task_completion_time: (samples: N, avg: N),\n"
        "      block_completion_time: (samples: N, avg: N),\n"
        "      task_completion_time_excl_backpressure_s: N,\n"
        "      block_size_bytes: (samples: N, avg: N),\n"
        "      block_size_rows: (samples: N, avg: N),\n"
        "      num_alive_actors: Z,\n"
        "      num_restarting_actors: Z,\n"
        "      num_pending_actors: Z,\n"
        "      obj_store_mem_internal_inqueue_blocks: Z,\n"
        "      obj_store_mem_internal_outqueue_blocks: Z,\n"
        "      obj_store_mem_freed: N,\n"
        "      obj_store_mem_spilled: Z,\n"
        "      obj_store_mem_used: A,\n"
        "      cpu_usage: Z,\n"
        "      gpu_usage: Z,\n"
        "      ray_remote_args: {'num_cpus': N, 'scheduling_strategy': 'SPREAD'},\n"
        "   },\n"
        "   operators_stats=[\n"
        "      OperatorStatsSummary(\n"
        "         operator_name='MapBatches(<lambda>)',\n"
        "         is_suboperator=False,\n"
        "         time_total_s=T,\n"
        f"         block_execution_summary_str={EXECUTION_STRING}\n"
        "         wall_time={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"
        "         cpu_time={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"
        "         memory={'min': 'T', 'max': 'T', 'mean': 'T'},\n"
        "         output_num_rows={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"
        "         output_size_bytes={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"  # noqa: E501
        "         node_count={'min': 'T', 'max': 'T', 'mean': 'T', 'count': 'T'},\n"
        "      ),\n"
        "   ],\n"
        "   iter_stats=IterStatsSummary(\n"
        "      wait_time=T,\n"
        "      get_ref_bundles_time=T,\n"
        "      get_time=T,\n"
        "      iter_blocks_local=None,\n"
        "      iter_blocks_remote=None,\n"
        "      iter_unknown_location=N,\n"
        "      iter_prefetched_bytes=None,\n"
        "      next_time=T,\n"
        "      format_time=T,\n"
        "      user_time=T,\n"
        "      total_time=T,\n"
        "   ),\n"
        "   global_bytes_spilled=M,\n"
        "   global_bytes_restored=M,\n"
        "   dataset_bytes_spilled=M,\n"
        "   parents=[\n"
        "      DatasetStatsSummary(\n"
        "         dataset_uuid=N,\n"
        "         base_name=ReadRange,\n"
        "         number=N,\n"
        "         extra_metrics={\n"
        "            average_num_outputs_per_task: N,\n"
        "            average_num_inputs_per_task: N,\n"
        "            num_output_blocks_per_task_s: N,\n"
        "            average_total_task_completion_time_s: N,\n"
        "            average_task_completion_excl_backpressure_time_s: N,\n"
        "            average_bytes_per_output: N,\n"
        "            obj_store_mem_internal_inqueue: Z,\n"
        "            obj_store_mem_internal_outqueue: Z,\n"
        "            obj_store_mem_pending_task_inputs: Z,\n"
        "            average_bytes_inputs_per_task: N,\n"
        "            average_rows_inputs_per_task: N,\n"
        "            average_bytes_outputs_per_task: N,\n"
        "            average_rows_outputs_per_task: N,\n"
        "            average_max_uss_per_task: H,\n"
        "            num_inputs_received: N,\n"
        "            num_row_inputs_received: N,\n"
        "            bytes_inputs_received: N,\n"
        "            num_task_inputs_processed: N,\n"
        "            bytes_task_inputs_processed: N,\n"
        "            bytes_inputs_of_submitted_tasks: N,\n"
        "            rows_inputs_of_submitted_tasks: N,\n"
        "            num_task_outputs_generated: N,\n"
        "            bytes_task_outputs_generated: N,\n"
        "            rows_task_outputs_generated: N,\n"
        "            row_outputs_taken: N,\n"
        "            block_outputs_taken: N,\n"
        "            num_outputs_taken: N,\n"
        "            bytes_outputs_taken: N,\n"
        "            num_outputs_of_finished_tasks: N,\n"
        "            bytes_outputs_of_finished_tasks: N,\n"
        "            rows_outputs_of_finished_tasks: N,\n"
        "            num_external_inqueue_blocks: Z,\n"
        "            num_external_inqueue_bytes: Z,\n"
        "            num_external_outqueue_blocks: Z,\n"
        "            num_external_outqueue_bytes: Z,\n"
        "            num_tasks_submitted: N,\n"
        "            num_tasks_running: Z,\n"
        "            num_tasks_have_outputs: N,\n"
        "            num_tasks_finished: N,\n"
        "            num_tasks_failed: Z,\n"
        "            block_generation_time: N,\n"
        "            task_submission_backpressure_time: N,\n"
        "            task_output_backpressure_time: Z,\n"
        "            task_completion_time_total_s: N,\n"
        "            task_completion_time: (samples: N, avg: N),\n"
        "            block_completion_time: (samples: N, avg: N),\n"
        "            task_completion_time_excl_backpressure_s: N,\n"
        "            block_size_bytes: (samples: N, avg: N),\n"
        "            block_size_rows: (samples: N, avg: N),\n"
        "            num_alive_actors: Z,\n"
        "            num_restarting_actors: Z,\n"
        "            num_pending_actors: Z,\n"
        "            obj_store_mem_internal_inqueue_blocks: Z,\n"
        "            obj_store_mem_internal_outqueue_blocks: Z,\n"
        "            obj_store_mem_freed: N,\n"
        "            obj_store_mem_spilled: Z,\n"
        "            obj_store_mem_used: A,\n"
        "            cpu_usage: Z,\n"
        "            gpu_usage: Z,\n"
        "            ray_remote_args: {'num_cpus': N, 'scheduling_strategy': 'SPREAD'},\n"  # noqa: E501
        "         },\n"
        "         operators_stats=[\n"
        "            OperatorStatsSummary(\n"
        "               operator_name='ReadRange',\n"
        "               is_suboperator=False,\n"
        "               time_total_s=T,\n"
        f"               block_execution_summary_str={EXECUTION_STRING}\n"
        "               wall_time={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"
        "               cpu_time={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"
        "               memory={'min': 'T', 'max': 'T', 'mean': 'T'},\n"
        "               output_num_rows={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"  # noqa: E501
        "               output_size_bytes={'min': 'T', 'max': 'T', 'mean': 'T', 'sum': 'T'},\n"  # noqa: E501
        "               node_count={'min': 'T', 'max': 'T', 'mean': 'T', 'count': 'T'},\n"  # noqa: E501
        "            ),\n"
        "         ],\n"
        "         iter_stats=IterStatsSummary(\n"
        "            wait_time=T,\n"
        "            get_ref_bundles_time=T,\n"
        "            get_time=T,\n"
        "            iter_blocks_local=None,\n"
        "            iter_blocks_remote=None,\n"
        "            iter_unknown_location=None,\n"
        "            iter_prefetched_bytes=None,\n"
        "            next_time=T,\n"
        "            format_time=T,\n"
        "            user_time=T,\n"
        "            total_time=T,\n"
        "         ),\n"
        "         global_bytes_spilled=M,\n"
        "         global_bytes_restored=M,\n"
        "         dataset_bytes_spilled=M,\n"
        "         parents=[\n"
        "            DatasetStatsSummary(\n"
        "               dataset_uuid=unknown_uuid,\n"
        "               base_name=None,\n"
        "               number=N,\n"
        "               extra_metrics={},\n"
        "               operators_stats=[],\n"
        "               iter_stats=IterStatsSummary(\n"
        "                  wait_time=T,\n"
        "                  get_ref_bundles_time=T,\n"
        "                  get_time=T,\n"
        "                  iter_blocks_local=None,\n"
        "                  iter_blocks_remote=None,\n"
        "                  iter_unknown_location=None,\n"
        "                  iter_prefetched_bytes=None,\n"
        "                  next_time=T,\n"
        "                  format_time=T,\n"
        "                  user_time=T,\n"
        "                  total_time=T,\n"
        "               ),\n"
        "               global_bytes_spilled=M,\n"
        "               global_bytes_restored=M,\n"
        "               dataset_bytes_spilled=M,\n"
        "               parents=[],\n"
        "            ),\n"
        "         ],\n"
        "      ),\n"
        "   ],\n"
        ")"
    )

    # NOTE(review): unlike check_stats above, this assertion omits the
    # `, stats` failure message — consider adding it for easier debugging.
    def check_stats2():
        stats = canonicalize(repr(ds2._plan.stats().to_summary()))
        assert stats == expected_stats2
        return True

    wait_for_condition(
        check_stats2,
        timeout=10,
        retry_interval_ms=1000,
    )


def test_dataset_stats_shuffle(ray_start_regular_shared):
    """Verify the canonicalized stats output for shuffle-based operators.

    Both RandomShuffle (fused with ReadRange) and a shuffling Repartition
    report their map and reduce suboperators separately.
    """
    ds = ray.data.range(1000, override_num_blocks=10)
    ds = ds.random_shuffle().repartition(1, shuffle=True)
    # canonicalize replaces volatile values (timings, counts) with placeholders
    # so the output can be compared against a stable expected string.
    stats = canonicalize(ds.materialize().stats())
    assert (
        stats
        == """Operator N ReadRange->RandomShuffle: executed in T

    Suboperator Z ReadRange->RandomShuffleMap: N tasks executed, N blocks produced
    * Remote wall time: T min, T max, T mean, T total
    * Remote cpu time: T min, T max, T mean, T total
    * UDF time: T min, T max, T mean, T total
    * Peak heap memory usage (MiB): H min, H max, H mean
    * Output num rows per block: N min, N max, N mean, N total
    * Output size bytes per block: N min, N max, N mean, N total
    * Output rows per task: N min, N max, N mean, N tasks used
    * Tasks per node: N min, N max, N mean; N nodes used
    * Operator throughput:
        * Total input num rows: N rows
        * Total output num rows: N rows
        * Ray Data throughput: N rows/s
        * Estimated single task throughput: N rows/s

    Suboperator N RandomShuffleReduce: N tasks executed, N blocks produced
    * Remote wall time: T min, T max, T mean, T total
    * Remote cpu time: T min, T max, T mean, T total
    * UDF time: T min, T max, T mean, T total
    * Peak heap memory usage (MiB): H min, H max, H mean
    * Output num rows per block: N min, N max, N mean, N total
    * Output size bytes per block: N min, N max, N mean, N total
    * Output rows per task: N min, N max, N mean, N tasks used
    * Tasks per node: N min, N max, N mean; N nodes used
    * Operator throughput:
        * Total input num rows: N rows
        * Total output num rows: N rows
        * Ray Data throughput: N rows/s
        * Estimated single task throughput: N rows/s

Operator N Repartition: executed in T

    Suboperator Z RepartitionMap: N tasks executed, N blocks produced
    * Remote wall time: T min, T max, T mean, T total
    * Remote cpu time: T min, T max, T mean, T total
    * UDF time: T min, T max, T mean, T total
    * Peak heap memory usage (MiB): H min, H max, H mean
    * Output num rows per block: N min, N max, N mean, N total
    * Output size bytes per block: N min, N max, N mean, N total
    * Output rows per task: N min, N max, N mean, N tasks used
    * Tasks per node: N min, N max, N mean; N nodes used
    * Operator throughput:
        * Total input num rows: N rows
        * Total output num rows: N rows
        * Ray Data throughput: N rows/s
        * Estimated single task throughput: N rows/s

    Suboperator N RepartitionReduce: N tasks executed, N blocks produced
    * Remote wall time: T min, T max, T mean, T total
    * Remote cpu time: T min, T max, T mean, T total
    * UDF time: T min, T max, T mean, T total
    * Peak heap memory usage (MiB): H min, H max, H mean
    * Output num rows per block: N min, N max, N mean, N total
    * Output size bytes per block: N min, N max, N mean, N total
    * Output rows per task: N min, N max, N mean, N tasks used
    * Tasks per node: N min, N max, N mean; N nodes used
    * Operator throughput:
        * Total input num rows: N rows
        * Total output num rows: N rows
        * Ray Data throughput: N rows/s
        * Estimated single task throughput: N rows/s

Dataset throughput:
    * Ray Data throughput: N rows/s
"""
    )


def test_dataset_stats_repartition(ray_start_regular_shared):
    """A non-shuffle repartition should be reported in the dataset stats."""
    dataset = ray.data.range(1000, override_num_blocks=10).repartition(
        1, shuffle=False
    )
    report = dataset.materialize().stats()
    assert "Repartition" in report, report


def test_dataset_stats_union(ray_start_regular_shared):
    """A union of a dataset with itself should be reported in the stats."""
    base = ray.data.range(1000, override_num_blocks=10)
    report = base.union(base).materialize().stats()
    assert "Union" in report, report


def test_dataset_stats_zip(ray_start_regular_shared):
    """Zipping a dataset with itself should be reported in the stats."""
    base = ray.data.range(1000, override_num_blocks=10)
    report = base.zip(base).materialize().stats()
    assert "Zip" in report, report


def test_dataset_stats_sort(ray_start_regular_shared):
    """Sorting should surface both map and reduce suboperators in the stats."""
    report = (
        ray.data.range(1000, override_num_blocks=10).sort("id").materialize().stats()
    )
    assert "SortMap" in report, report
    assert "SortReduce" in report, report


def test_dataset_stats_from_items(ray_start_regular_shared):
    """Datasets built via from_items should report a FromItems operator."""
    report = ray.data.from_items(range(10)).materialize().stats()
    assert "FromItems" in report, report


def test_dataset_stats_range(ray_start_regular_shared, tmp_path):
    """Verify the canonicalized stats output of a fused ReadRange->Map pipeline."""
    ds = ray.data.range(1000, override_num_blocks=10).map(lambda x: x)
    stats = canonicalize(ds.materialize().stats())
    # Single fused operator: per-block/per-task aggregates plus operator- and
    # dataset-level throughput lines.
    assert stats == (
        f"Operator N ReadRange->Map(<lambda>): {EXECUTION_STRING}\n"
        f"* Remote wall time: T min, T max, T mean, T total\n"
        f"* Remote cpu time: T min, T max, T mean, T total\n"
        f"* UDF time: T min, T max, T mean, T total\n"
        f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
        f"* Output num rows per block: N min, N max, N mean, N total\n"
        f"* Output size bytes per block: N min, N max, N mean, N total\n"
        f"* Output rows per task: N min, N max, N mean, N tasks used\n"
        f"* Tasks per node: N min, N max, N mean; N nodes used\n"
        f"* Operator throughput:\n"
        f"    * Total input num rows: N rows\n"
        f"    * Total output num rows: N rows\n"
        f"    * Ray Data throughput: N rows/s\n"
        f"    * Estimated single task throughput: N rows/s\n"
        f"\n"
        f"Dataset throughput:\n"
        f"    * Ray Data throughput: N rows/s\n"
    )


def test_dataset_split_stats(ray_start_regular_shared, tmp_path, restore_data_context):
    """Verify the canonicalized stats output of each split of a dataset.

    Each split's stats should show three operators: the upstream fused
    ReadRange->Map, the Split itself, and the post-split Map.
    """
    # NOTE: It's critical to preserve ordering for assertions in this test to work
    DataContext.get_current().execution_options.preserve_order = True

    ds = ray.data.range(100, override_num_blocks=10).map(
        column_udf("id", lambda x: x + 1)
    )
    dses = ds.split_at_indices([49])
    dses = [ds.map(column_udf("id", lambda x: x + 1)) for ds in dses]
    # Both splits are expected to produce the same canonicalized stats shape.
    for ds_ in dses:
        stats = canonicalize(ds_.materialize().stats())

        assert stats == (
            f"Operator N ReadRange->Map(<lambda>): {EXECUTION_STRING}\n"
            f"* Remote wall time: T min, T max, T mean, T total\n"
            f"* Remote cpu time: T min, T max, T mean, T total\n"
            f"* UDF time: T min, T max, T mean, T total\n"
            f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
            f"* Output num rows per block: N min, N max, N mean, N total\n"
            f"* Output size bytes per block: N min, N max, N mean, N total\n"
            f"* Output rows per task: N min, N max, N mean, N tasks used\n"
            f"* Tasks per node: N min, N max, N mean; N nodes used\n"
            f"* Operator throughput:\n"
            f"    * Total input num rows: N rows\n"
            f"    * Total output num rows: N rows\n"
            f"    * Ray Data throughput: N rows/s\n"
            f"    * Estimated single task throughput: N rows/s\n"
            f"\n"
            f"Operator N Split: {EXECUTION_STRING}\n"
            f"* Remote wall time: T min, T max, T mean, T total\n"
            f"* Remote cpu time: T min, T max, T mean, T total\n"
            f"* UDF time: T min, T max, T mean, T total\n"
            f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
            f"* Output num rows per block: N min, N max, N mean, N total\n"
            f"* Output size bytes per block: N min, N max, N mean, N total\n"
            f"* Output rows per task: N min, N max, N mean, N tasks used\n"
            f"* Tasks per node: N min, N max, N mean; N nodes used\n"
            f"* Operator throughput:\n"
            f"    * Total input num rows: N rows\n"
            f"    * Total output num rows: N rows\n"
            f"    * Ray Data throughput: N rows/s\n"
            f"    * Estimated single task throughput: N rows/s\n"
            f"\n"
            f"Operator N Map(<lambda>): {EXECUTION_STRING}\n"
            f"* Remote wall time: T min, T max, T mean, T total\n"
            f"* Remote cpu time: T min, T max, T mean, T total\n"
            f"* UDF time: T min, T max, T mean, T total\n"
            f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
            f"* Output num rows per block: N min, N max, N mean, N total\n"
            f"* Output size bytes per block: N min, N max, N mean, N total\n"
            f"* Output rows per task: N min, N max, N mean, N tasks used\n"
            f"* Tasks per node: N min, N max, N mean; N nodes used\n"
            f"* Operator throughput:\n"
            f"    * Total input num rows: N rows\n"
            f"    * Total output num rows: N rows\n"
            f"    * Ray Data throughput: N rows/s\n"
            f"    * Estimated single task throughput: N rows/s\n"
            f"\n"
            f"Dataset throughput:\n"
            f"    * Ray Data throughput: N rows/s\n"
        )


def test_calculate_blocks_stats(ray_start_regular_shared, op_two_block):
    """Check that the per-operator aggregates produced by
    ``DatasetStats.to_summary`` (rows, bytes, wall/cpu time, node counts)
    match values recomputed directly from the raw block parameters."""
    block_params, block_meta_list = op_two_block
    summary = DatasetStats(
        metadata={"Read": block_meta_list},
        parent=None,
    ).to_summary()
    op_summary = summary.operators_stats[0]

    def expected_agg(values):
        # Expected {min, max, mean, sum} aggregate over raw per-block values.
        return {
            "min": min(values),
            "max": max(values),
            "mean": np.mean(values),
            "sum": sum(values),
        }

    assert op_summary.output_num_rows == expected_agg(block_params["num_rows"])
    assert op_summary.output_size_bytes == expected_agg(block_params["size_bytes"])
    assert op_summary.wall_time == expected_agg(block_params["wall_time"])
    assert op_summary.cpu_time == expected_agg(block_params["cpu_time"])

    # Node stats aggregate the number of tasks per node, not per-block values.
    tasks_per_node = Counter(block_params["node_id"])
    assert op_summary.node_count == {
        "min": min(tasks_per_node.values()),
        "max": max(tasks_per_node.values()),
        "mean": np.mean(list(tasks_per_node.values())),
        "count": len(tasks_per_node),
    }


def test_summarize_blocks(ray_start_regular_shared, op_two_block):
    """Check each rendered line of ``DatasetStatsSummary.to_string`` against
    values recomputed from the raw block parameters."""
    block_params, block_meta_list = op_two_block
    stats = DatasetStats(
        metadata={"Read": block_meta_list},
        parent=None,
    )
    stats.dataset_uuid = "test-uuid"

    rendered = stats.to_summary().to_string().split("\n")

    # Header line: overall wall-clock span across all block executions
    # (clamped at zero, rounded to 2 decimals).
    latest_end = max(m.exec_stats.end_time_s for m in block_meta_list)
    earliest_start = min(m.exec_stats.start_time_s for m in block_meta_list)
    span_s = max(round(latest_end - earliest_start, 2), 0)
    assert (
        f"Operator 0 Read: 2 tasks executed, 2 blocks produced in {span_s}s"
        == rendered[0]
    )

    wall = block_params["wall_time"]
    assert (
        f"* Remote wall time: {min(wall)}s min, {max(wall)}s max, "
        f"{np.mean(wall)}s mean, {sum(wall)}s total" == rendered[1]
    )

    cpu = block_params["cpu_time"]
    assert (
        f"* Remote cpu time: {min(cpu)}s min, {max(cpu)}s max, "
        f"{np.mean(cpu)}s mean, {sum(cpu)}s total" == rendered[2]
    )

    udf = block_params["udf_time"]
    assert (
        f"* UDF time: {min(udf)}s min, {max(udf)}s max, "
        f"{np.mean(udf)}s mean, {sum(udf)}s total" == rendered[3]
    )

    # Memory line: min/max are floats in MiB; the mean is truncated to int.
    mib = 1024 * 1024
    uss = block_params["uss_bytes"]
    assert (
        f"* Peak heap memory usage (MiB): {min(uss) / mib} min, "
        f"{max(uss) / mib} max, {int(np.mean(uss) / mib)} mean" == rendered[4]
    )

    rows = block_params["num_rows"]
    assert (
        f"* Output num rows per block: {min(rows)} min, {max(rows)} max, "
        f"{int(np.mean(rows))} mean, {sum(rows)} total" == rendered[5]
    )

    size = block_params["size_bytes"]
    assert (
        f"* Output size bytes per block: {min(size)} min, {max(size)} max, "
        f"{int(np.mean(size))} mean, {sum(size)} total" == rendered[6]
    )

    tasks_used = len(set(block_params["task_idx"]))
    assert (
        f"* Output rows per task: {min(rows)} min, {max(rows)} max, "
        f"{int(np.mean(rows))} mean, {tasks_used} tasks used" == rendered[7]
    )

    # Node line aggregates task counts per node.
    tasks_per_node = Counter(block_params["node_id"])
    per_node = list(tasks_per_node.values())
    assert (
        f"* Tasks per node: {min(per_node)} min, {max(per_node)} max, "
        f"{int(np.mean(per_node))} mean; {len(tasks_per_node)} nodes used"
        == rendered[8]
    )


def test_get_total_stats(ray_start_regular_shared, op_two_block):
    """Validate the aggregate getters exposed on a dataset stats summary.

    Covers `get_total_wall_time()`, `get_total_time_all_blocks()`,
    `get_total_cpu_time()` and `get_max_heap_memory()`, all computed from
    operator-level statistics.
    """
    params, meta_list = op_two_block
    summary = DatasetStats(
        metadata={"Read": meta_list},
        parent=None,
    ).to_summary()
    first_op = summary.operators_stats[0]

    # With a single operator, total wall time is simply the span from the
    # earliest block start to the latest block end.
    expected_wall = first_op.latest_end_time - first_op.earliest_start_time
    assert summary.get_total_wall_time() == expected_wall

    # Summing across blocks yields the cumulative per-block wall time.
    assert summary.get_total_time_all_blocks() == sum(params["wall_time"])

    # CPU time aggregates to the recorded sum; heap memory to the recorded max.
    assert summary.get_total_cpu_time() == first_op.cpu_time.get("sum")
    assert summary.get_max_heap_memory() == first_op.memory.get("max")


@pytest.mark.skip(
    reason="Temporarily disable to deflake rest of test suite. "
    "See: https://github.com/ray-project/ray/pull/40173"
)
def test_streaming_stats_full(ray_start_regular_shared, restore_data_context):
    """End-to-end check of the full canonicalized stats report, covering the
    operator section, the iterator time breakdown, and dataset throughput."""
    ds = ray.data.range(5, override_num_blocks=5).map(column_udf("id", lambda x: x + 1))
    ds.take_all()
    stats = canonicalize(ds.stats())
    assert (
        stats
        == f"""Operator N ReadRange->Map(<lambda>): {EXECUTION_STRING}
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* UDF time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): H min, H max, H mean
* Output num rows per block: N min, N max, N mean, N total
* Output size bytes per block: N min, N max, N mean, N total
* Output rows per task: N min, N max, N mean, N tasks used
* Tasks per node: N min, N max, N mean; N nodes used
* Operator throughput:
    * Total input num rows: N rows
    * Total output num rows: N rows
    * Ray Data throughput: N rows/s
    * Estimated single task throughput: N rows/s

Dataset iterator time breakdown:
* Total time overall: T
    * Total time in Ray Data iterator initialization code: T
    * Total time user thread is blocked by Ray Data iter_batches: T
    * Total time spent waiting for the first batch after starting iteration: T
    * Total execution time for user thread: T
* Batch iteration time breakdown (summed across prefetch threads):
    * In get RefBundles: T min, T max, T avg, T total
    * In ray.get(): T min, T max, T avg, T total
    * In batch creation: T min, T max, T avg, T total
    * In batch formatting: T min, T max, T avg, T total

Dataset throughput:
    * Ray Data throughput: N rows/s
"""
    )


def test_write_ds_stats(ray_start_regular_shared, tmp_path):
    """Stats for write operations: the fused read->write case and the
    materialize-then-write case both produce the expected canonical report,
    and the internal write dataset shares the same stats string."""
    ds = ray.data.range(100, override_num_blocks=100)
    ds.write_parquet(str(tmp_path))
    stats = ds.stats()

    assert (
        canonicalize(stats)
        == f"""Operator N ReadRange->Write: {EXECUTION_STRING}
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* UDF time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): H min, H max, H mean
* Output num rows per block: N min, N max, N mean, N total
* Output size bytes per block: N min, N max, N mean, N total
* Output rows per task: N min, N max, N mean, N tasks used
* Tasks per node: N min, N max, N mean; N nodes used
* Operator throughput:
    * Total input num rows: N rows
    * Total output num rows: N rows
    * Ray Data throughput: N rows/s
    * Estimated single task throughput: N rows/s

Dataset throughput:
    * Ray Data throughput: N rows/s
"""
    )

    # The write dataset's stats must match the user-visible dataset's stats.
    assert stats == ds._write_ds.stats()

    # Materializing first keeps the write as a separate (unfused) operator.
    ds = (
        ray.data.range(100, override_num_blocks=100)
        .map_batches(lambda x: x)
        .materialize()
    )
    ds.write_parquet(str(tmp_path))
    stats = ds.stats()

    assert (
        canonicalize(stats)
        == f"""Operator N ReadRange->MapBatches(<lambda>): {EXECUTION_STRING}
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* UDF time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): H min, H max, H mean
* Output num rows per block: N min, N max, N mean, N total
* Output size bytes per block: N min, N max, N mean, N total
* Output rows per task: N min, N max, N mean, N tasks used
* Tasks per node: N min, N max, N mean; N nodes used
* Operator throughput:
    * Total input num rows: N rows
    * Total output num rows: N rows
    * Ray Data throughput: N rows/s
    * Estimated single task throughput: N rows/s

Operator N Write: {EXECUTION_STRING}
* Remote wall time: T min, T max, T mean, T total
* Remote cpu time: T min, T max, T mean, T total
* UDF time: T min, T max, T mean, T total
* Peak heap memory usage (MiB): H min, H max, H mean
* Output num rows per block: N min, N max, N mean, N total
* Output size bytes per block: N min, N max, N mean, N total
* Output rows per task: N min, N max, N mean, N tasks used
* Tasks per node: N min, N max, N mean; N nodes used
* Operator throughput:
    * Total input num rows: N rows
    * Total output num rows: N rows
    * Ray Data throughput: N rows/s
    * Estimated single task throughput: N rows/s

Dataset throughput:
    * Ray Data throughput: N rows/s
"""
    )

    assert stats == ds._write_ds.stats()


def test_time_backpressure(ray_start_regular_shared, restore_data_context):
    """Task-submission backpressure time is recorded in extra_metrics when a
    backpressure policy delays task submission."""

    class TimedBackpressurePolicy(BackpressurePolicy):
        # Number of times `can_add_input` has returned False so far.
        COUNT = 0

        def can_add_input(self, op: "PhysicalOperator") -> bool:
            # Reject the first two calls to force measurable backpressure
            # time, then admit inputs (with a small delay) thereafter.
            if TimedBackpressurePolicy.COUNT > 1:
                time.sleep(0.01)
                return True
            else:
                TimedBackpressurePolicy.COUNT += 1
                return False

    context = DataContext.get_current()
    context.verbose_stats_logs = True
    context.set_config(
        ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY, [TimedBackpressurePolicy]
    )

    def f(x):
        time.sleep(0.01)
        return x

    ds = ray.data.range(10000).map_batches(f).materialize()
    # The policy's rejections must be attributed to backpressure time.
    assert ds._plan.stats().extra_metrics["task_submission_backpressure_time"] > 0


def test_runtime_metrics(ray_start_regular_shared):
    """Test that the runtime-metrics report is internally consistent.

    Parses the formatted `runtime_metrics()` string and checks that each
    entry's time is bounded by the total and that the reported percentages
    match the reported times.
    """
    from math import isclose

    def time_to_seconds(time_str):
        """Convert a formatted duration ('12us', '34ms', '5.6s') to seconds."""
        if time_str.endswith("us"):
            # Convert microseconds to seconds
            return float(time_str[:-2]) / (1000 * 1000)
        elif time_str.endswith("ms"):
            # Convert milliseconds to seconds
            return float(time_str[:-2]) / 1000
        elif time_str.endswith("s"):
            # Already in seconds, just remove the 's' and convert to float
            return float(time_str[:-1])

    f = dummy_map_batches_sleep(0.01)
    ds = ray.data.range(100).map(f).materialize().map(f).materialize()
    metrics_str = ds._plan.stats().runtime_metrics()

    # Dictionary to store the metrics for testing
    metrics_dict = {}

    # Regular expression to match the pattern of each metric line. The suffix
    # alternation includes "us" so that microsecond-formatted entries are not
    # silently skipped ("us" could not match "(?:ms|s)", even though
    # time_to_seconds already handles it).
    pattern = re.compile(r"\* (.+?): ([\d\.]+(?:us|ms|s)) \(([\d\.]+)%\)")

    # Split the input string into lines and iterate over them
    for line in metrics_str.split("\n"):
        match = pattern.match(line)
        if match:
            # Extracting the operator name, time, and percentage
            operator_name, time_str, percent_str = match.groups()
            # Converting percentage to float and keeping time as seconds
            metrics_dict[operator_name] = (
                time_to_seconds(time_str),
                float(percent_str),
            )

    total_time, total_percent = metrics_dict.pop("Total")
    assert total_percent == 100

    for time_s, percent in metrics_dict.values():
        assert time_s <= total_time
        # Check percentage, this is done with some expected loss of precision
        # due to rounding in the initial output.
        assert isclose(percent, time_s / total_time * 100, rel_tol=0.01)


def test_per_node_metrics_basic(ray_start_regular_shared, restore_data_context):
    """Basic test to ensure per-node metrics are populated."""
    ctx = DataContext.get_current()
    ctx.enable_per_node_metrics = True

    def _sum_net_metrics(per_node_metrics: Dict[str, NodeMetrics]) -> Dict[str, float]:
        # Aggregate each metric across all nodes into a single total.
        sum_metrics = defaultdict(float)
        for metrics in per_node_metrics.values():
            for metric, value in metrics.items():
                sum_metrics[metric] += value
        return sum_metrics

    # Intercept the stats actor so the metrics pushed to it can be inspected.
    with patch("ray.data._internal.stats.get_or_create_stats_actor") as mock_get_actor:
        mock_actor_handle = MagicMock()
        mock_get_actor.return_value = mock_actor_handle

        ds = ray.data.range(20).map_batches(lambda batch: batch).materialize()
        metrics = ds._plan.stats().extra_metrics

        calls = mock_actor_handle.update_execution_metrics.remote.call_args_list
        assert len(calls) > 0

        # The per-node metrics dict is the last positional argument of the
        # final update call.
        last_args, _ = calls[-1]
        per_node_metrics = last_args[-1]

        assert isinstance(per_node_metrics, dict)
        assert len(per_node_metrics) >= 1

        # Every NodeMetrics field must be present for every node.
        for nm in per_node_metrics.values():
            for f in fields(NodeMetrics):
                assert f.name in nm

        # basic checks to make sure metrics are populated
        assert any(nm["num_tasks_finished"] > 0 for nm in per_node_metrics.values())
        assert any(
            nm["bytes_outputs_of_finished_tasks"] > 0
            for nm in per_node_metrics.values()
        )
        assert any(
            nm["blocks_outputs_of_finished_tasks"] > 0
            for nm in per_node_metrics.values()
        )

        # Summing over nodes must reproduce the dataset-level totals.
        net_metrics = _sum_net_metrics(per_node_metrics)
        assert net_metrics["num_tasks_finished"] == metrics["num_tasks_finished"]
        assert (
            net_metrics["bytes_outputs_of_finished_tasks"]
            == metrics["bytes_outputs_of_finished_tasks"]
        )


@pytest.mark.parametrize("enable_metrics", [True, False])
def test_per_node_metrics_toggle(
    ray_start_regular_shared, restore_data_context, enable_metrics
):
    """Per-node metrics are reported iff `enable_per_node_metrics` is set."""
    DataContext.get_current().enable_per_node_metrics = enable_metrics

    with patch("ray.data._internal.stats.get_or_create_stats_actor") as get_actor_mock:
        actor_stub = MagicMock()
        get_actor_mock.return_value = actor_stub

        ray.data.range(10000).map(lambda x: x).materialize()

        recorded = actor_stub.update_execution_metrics.remote.call_args_list
        assert len(recorded) > 0

        # Per-node metrics ride along as the final positional argument.
        positional, _ = recorded[-1]
        node_metrics = positional[-1]

        # The payload is present exactly when the feature is enabled.
        if enable_metrics:
            assert node_metrics is not None
        else:
            assert node_metrics is None


def test_task_duration_stats():
    """TaskDurationStats tracks a running mean/stddev (Welford) correctly."""
    tracker = TaskDurationStats()

    # A fresh tracker reports zeroed statistics.
    assert tracker.count() == 0
    assert tracker.mean() == 0.0
    assert tracker.stddev() == 0.0

    # Feed a fixed sample set and cross-check against numpy.
    samples = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
    for sample in samples:
        tracker.add_duration(sample)

    assert tracker.count() == len(samples)
    assert pytest.approx(tracker.mean()) == np.mean(samples)
    # ddof=1 selects the sample (not population) standard deviation.
    assert pytest.approx(tracker.stddev()) == np.std(samples, ddof=1)


def test_dataset_throughput_calculation(ray_start_regular_shared):
    """Test throughput calculations using mock block stats."""

    def make_block(start_s, end_s, rows):
        # Fabricate execution stats for a block spanning [start_s, end_s].
        es = BlockExecStats()
        es.start_time_s = start_s
        es.end_time_s = end_s
        es.wall_time_s = end_s - start_s
        es.cpu_time_s = es.wall_time_s
        return BlockStats(num_rows=rows, size_bytes=None, exec_stats=es)

    # Three overlapping blocks of 100 rows each.
    metadata = {
        "Map": [
            make_block(0.0, 2.0, 100),
            make_block(0.5, 2.5, 100),
            make_block(1.0, 3.0, 100),
        ]
    }
    summary = DatasetStats(metadata=metadata, parent=None).to_summary()

    # Throughput = total rows / total execution duration:
    # 300 rows over a 3.0s span (max end - min start) -> 100 rows/s.
    assert summary.num_rows_per_s == 100


def test_operator_throughput_calculation(ray_start_regular_shared):
    """Test operator throughput calculations using mock BlockStats."""

    def create_block_stats(start_time, end_time, num_rows, task_idx):
        # Fabricate execution stats for a block produced by task `task_idx`.
        exec_stats = BlockExecStats()
        exec_stats.start_time_s = start_time
        exec_stats.end_time_s = end_time
        exec_stats.wall_time_s = end_time - start_time
        exec_stats.cpu_time_s = exec_stats.wall_time_s
        exec_stats.task_idx = task_idx

        return BlockStats(num_rows=num_rows, size_bytes=None, exec_stats=exec_stats)

    # Two tasks running concurrently over the same [0.0, 2.0] interval.
    blocks_stats = [
        create_block_stats(0.0, 2.0, 100, 1),
        create_block_stats(0.0, 2.0, 100, 2),
    ]

    summary = OperatorStatsSummary.from_block_metadata(
        operator_name="MockOperator",
        block_stats=blocks_stats,
        is_sub_operator=False,
    )

    # Total rows = 200
    # Total operator wall time (from earliest start to latest end) = 2.0s
    # Sum of individual task wall times = 2.0s + 2.0s = 4.0s
    # Overall throughput: Total rows / Total operator wall time
    assert summary.num_rows_per_s == 200 / (2.0 - 0.0)

    # Estimated single task throughput: Total rows / Sum of individual task wall times
    assert summary.num_rows_per_task_s == 200 / (2.0 + 2.0)


# NOTE: All tests above share a Ray cluster, while the tests below do not.
# Reorder these tests only with great care, so that this invariant is retained!


def test_individual_operator_num_rows(shutdown_only):
    """An operator's input row count equals its parent operator's output count."""
    ray.shutdown()
    ray.init(num_cpus=2)

    rows = [{"id": i, "value": i * 1.5, "category": i % 5} for i in range(500)]
    ds = (
        ray.data.from_items(rows)
        .map(lambda x: {**x, "value_squared": x["value"] ** 2})
        .filter(lambda x: x["value_squared"] > 300)
    )

    stats_text = ds.materialize().stats()

    # Pull the row counts out of the formatted stats report.
    op0_out_match = re.search(
        r"Operator 0.*?Total output num rows: (\d+)", stats_text, re.DOTALL
    )
    op1_in_match = re.search(
        r"Operator 1.*?Total input num rows: (\d+)", stats_text, re.DOTALL
    )
    op0_rows_out = int(op0_out_match.group(1))
    op1_rows_in = int(op1_in_match.group(1))

    assert op0_rows_out == 500
    assert op0_rows_out == op1_rows_in


def test_sub_operator_num_rows(shutdown_only):
    # The input num rows of sub operator:
    # The first sub-operator: total output from all parent nodes
    # Subsequent sub-operators: output of the previous sub-operator
    ray.shutdown()
    ray.init(num_cpus=2)

    # Left-outer join of 500 rows with 300 rows; the join plan contains
    # multiple sub-operators whose row accounting is checked below.
    data1 = [{"id": i, "value1": i * 1.5, "category1": i % 5} for i in range(500)]
    ds1 = ray.data.from_items(data1)
    data2 = [{"id": i, "value2": i * 1.5, "category2": i % 5} for i in range(300)]
    ds2 = ray.data.from_items(data2)
    ds = ds1.join(ds2, join_type="left_outer", num_partitions=2)

    stats_output = ds.materialize().stats()

    # Regexes to pull row counts out of the formatted stats report.
    patterns = {
        "operator0_output": re.compile(
            r"Operator 0.*?Total output num rows: (\d+)", re.DOTALL
        ),
        "subop0_input": re.compile(
            r"Suboperator 0.*?Total input num rows: (\d+)", re.DOTALL
        ),
        "subop0_output": re.compile(
            r"Suboperator 0.*?Total output num rows: (\d+)", re.DOTALL
        ),
        "subop1_input": re.compile(
            r"Suboperator 1.*?Total input num rows: (\d+)", re.DOTALL
        ),
    }

    extracted_data = {}
    for key, pattern in patterns.items():
        match = pattern.search(stats_output)
        if match:
            extracted_data[key] = int(match.group(1))
        else:
            # Missing section: record None so the asserts below fail loudly.
            extracted_data[key] = None

    # 800 = 500 + 300 rows combined from both join inputs.
    assert extracted_data["operator0_output"] == 500
    assert extracted_data["subop0_output"] == 800
    assert extracted_data["operator0_output"] == extracted_data["subop0_input"]
    assert extracted_data["subop0_output"] == extracted_data["subop1_input"]


@pytest.mark.parametrize("verbose_stats_logs", [True, False])
def test_spilled_stats(shutdown_only, verbose_stats_logs, restore_data_context):
    """Spill/restore byte counts appear in the stats report when the dataset
    exceeds the object store capacity, and are zero when it fits."""
    context = DataContext.get_current()
    context.verbose_stats_logs = verbose_stats_logs
    context.enable_get_object_locations_for_metrics = True
    # The object store is about 100MB.
    ray.init(object_store_memory=100e6)
    # The size of dataset is 1000*80*80*4*8B, about 200MB.
    ds = ray.data.range(1000 * 80 * 80 * 4).map_batches(lambda x: x).materialize()

    extra_metrics = gen_extra_metrics_str(
        MEM_SPILLED_EXTRA_METRICS_TASK_BACKPRESSURE,
        verbose_stats_logs,
    )
    expected_stats = (
        f"Operator N ReadRange->MapBatches(<lambda>): {EXECUTION_STRING}\n"
        f"* Remote wall time: T min, T max, T mean, T total\n"
        f"* Remote cpu time: T min, T max, T mean, T total\n"
        f"* UDF time: T min, T max, T mean, T total\n"
        f"* Peak heap memory usage (MiB): H min, H max, H mean\n"
        f"* Output num rows per block: N min, N max, N mean, N total\n"
        f"* Output size bytes per block: N min, N max, N mean, N total\n"
        f"* Output rows per task: N min, N max, N mean, N tasks used\n"
        f"* Tasks per node: N min, N max, N mean; N nodes used\n"
        f"* Operator throughput:\n"
        f"    * Total input num rows: N rows\n"
        f"    * Total output num rows: N rows\n"
        f"    * Ray Data throughput: N rows/s\n"
        f"    * Estimated single task throughput: N rows/s\n"
        f"{extra_metrics}\n"
        f"Cluster memory:\n"
        f"* Spilled to disk: M\n"
        f"* Restored from disk: M\n"
        f"\n"
        f"Dataset throughput:\n"
        f"    * Ray Data throughput: N rows/s\n"
        f"{gen_runtime_metrics_str(['ReadRange->MapBatches(<lambda>)'], verbose_stats_logs)}"  # noqa: E501
    )

    assert canonicalize(ds.stats(), filter_global_stats=False) == expected_stats

    # Around 100MB should be spilled (200MB - 100MB)
    assert ds._plan.stats().global_bytes_spilled > 100e6

    ds = (
        ray.data.range(1000 * 80 * 80 * 4)
        .map_batches(lambda x: x)
        .materialize()
        .map_batches(lambda x: x)
        .materialize()
    )
    # two map_batches operators, twice the spillage
    assert ds._plan.stats().dataset_bytes_spilled > 200e6

    # The size of dataset is around 50MB, there should be no spillage
    ds = (
        ray.data.range(250 * 80 * 80 * 4, override_num_blocks=1)
        .map_batches(lambda x: x)
        .materialize()
    )

    assert ds._plan.stats().dataset_bytes_spilled == 0


def test_stats_actor_metrics():
    """Final per-operator metrics reported to the stats actor match the
    dataset's own stats (spilled/freed bytes, generated outputs, timing)."""
    ray.init(object_store_memory=100e6)
    with patch_update_stats_actor() as update_fn:
        ds = ray.data.range(1000 * 80 * 80 * 4).map_batches(lambda x: x).materialize()

    # last emitted metrics from map operator
    final_metric = update_fn.call_args_list[-1].args[1][-1]

    assert final_metric.obj_store_mem_spilled == ds._plan.stats().dataset_bytes_spilled
    assert (
        final_metric.obj_store_mem_freed
        == ds._plan.stats().extra_metrics["obj_store_mem_freed"]
    )
    assert (
        final_metric.bytes_task_outputs_generated == 1000 * 80 * 80 * 4 * 8
    )  # 8B per int
    assert final_metric.rows_task_outputs_generated == 1000 * 80 * 80 * 4
    # There should be nothing in object store at the end of execution.

    # The update call also carries the dataset ID and operator IDs in order.
    args = update_fn.call_args_list[-1].args
    assert args[0] == f"dataset_{ds._uuid}_0"
    assert args[2][0] == "Input_0"
    assert args[2][1] == "ReadRange->MapBatches(<lambda>)_1"

    def sleep_three(x):
        import time

        time.sleep(3)
        return x

    # Block generation time accumulates task runtime: 3 tasks x 3s >= 9s.
    with patch_update_stats_actor() as update_fn:
        ds = ray.data.range(3).map_batches(sleep_three, batch_size=1).materialize()

    final_metric = update_fn.call_args_list[-1].args[1][-1]
    assert final_metric.block_generation_time >= 9


def test_stats_actor_iter_metrics():
    """Iteration stats pushed to the stats actor match the dataset's stats,
    and the reported dataset ID carries the run-index suffix."""
    # Use an explicit int row count; `ray.data.range` takes an integer, not
    # the float `1e6`.
    ds = ray.data.range(1_000_000).map_batches(lambda x: x)
    with patch_update_stats_actor_iter() as update_fn:
        ds.take_all()

    ds_stats = ds._plan.stats()
    # The final update carries (stats, dataset_id).
    final_stats = update_fn.call_args_list[-1].args[0]

    assert final_stats == ds_stats
    assert f"dataset_{ds._uuid}_0" == update_fn.call_args_list[-1].args[1]


def test_dataset_name_and_id():
    """Dataset naming and ID generation: deprecated `_set_name`/`_name`, the
    public `set_name`/`name` API, name propagation across transformations,
    and the per-execution run index in the dataset ID."""
    # Test deprecated APIs: _set_name and _name
    ds = ray.data.range(1)
    ds._set_name("test_ds")
    assert ds._name == "test_ds"

    ds = ray.data.range(100, override_num_blocks=20).map_batches(lambda x: x)
    ds.set_name("test_ds")
    assert ds.name == "test_ds"
    assert str(ds) == (
        "MapBatches(<lambda>)\n"
        "+- Dataset(name=test_ds, num_rows=100, schema={id: int64})"
    )

    def _run_dataset(ds, expected_name, expected_run_index):
        # Iterate the dataset and check the dataset ID reported to the
        # stats actor: "<name>_<uuid>_<run_index>".
        with patch_update_stats_actor() as update_fn:
            for _ in ds.iter_batches():
                pass

        assert (
            update_fn.call_args_list[-1].args[0]
            == f"{expected_name}_{ds._uuid}_{expected_run_index}"
        )

    _run_dataset(ds, "test_ds", 0)

    # Run the dataset again, the execution index should be incremented
    _run_dataset(ds, "test_ds", 1)

    # Names persist after an execution
    ds = ds.random_shuffle()
    assert ds.name == "test_ds"
    _run_dataset(ds, "test_ds", 0)

    ds.set_name("test_ds_two")
    ds = ds.map_batches(lambda x: x)
    assert ds.name == "test_ds_two"
    _run_dataset(ds, "test_ds_two", 0)

    # Clearing the name falls back to the default "dataset" prefix.
    ds.set_name(None)
    ds = ds.map_batches(lambda x: x)
    assert ds.name is None
    _run_dataset(ds, "dataset", 0)

    ds = ray.data.range(100, override_num_blocks=20)
    ds.set_name("very_loooooooong_name")
    assert (
        str(ds)
        == "Dataset(name=very_loooooooong_name, num_rows=100, schema={id: int64})"
    )


def test_dataset_id_train_ingest():
    """Test that the dataset ID is properly set for training ingestion jobs."""
    num_epochs = 3
    driver_script = f"""
import ray

ds = ray.data.range(100, override_num_blocks=20).map_batches(lambda x: x)
ds.set_name("train")
ds._set_uuid("1234")

split = ds.streaming_split(1)[0]

for epoch in range({num_epochs}):
    for _ in split.iter_batches():
        pass
    """
    # Need to run the code as a subprocess, because the executor
    # runs on the SplitCoordinator actor.
    out = run_string_as_driver(driver_script)
    # Each epoch is a separate execution, so the run index in the dataset ID
    # should increment from 0 to num_epochs - 1.
    for i in range(num_epochs):
        dataset_id = f"train_1234_{i}"
        assert f"Starting execution of Dataset {dataset_id}" in out


def test_executor_logs_metrics_on_operator_completion(caplog, propagate_logs):
    """Test that operator completion metrics are logged exactly once per operator."""
    EXPECTED_COMPLETION_MESSAGE = (
        "Operator TaskPoolMapOperator[ReadRange] completed. Operator Metrics:"
    )

    with caplog.at_level(logging.DEBUG):
        ray.data.range(1).take_all()

    # Collect every captured log record containing the completion message.
    matching = [
        record.message
        for record in caplog.records
        if EXPECTED_COMPLETION_MESSAGE in record.message
    ]
    actual_count = len(matching)
    assert actual_count == 1, (
        f"Expected operator completion message to appear exactly once, "
        f"but found {actual_count} occurrences"
    )


def test_stats_actor_datasets(ray_start_cluster):
    """The stats actor records dataset- and operator-level progress and state
    for a materialized dataset."""
    ds = ray.data.range(100, override_num_blocks=20).map_batches(lambda x: x)
    ds.set_name("test_stats_actor_datasets")
    ds.materialize()
    stats_actor = get_or_create_stats_actor()

    # Look up this dataset's entry by its name prefix.
    datasets = ray.get(stats_actor.get_datasets.remote())
    dataset_name = list(filter(lambda x: x.startswith(ds.name), datasets))
    assert len(dataset_name) == 1
    dataset = datasets[dataset_name[0]]

    assert dataset["state"] == "FINISHED"
    assert dataset["progress"] == 20
    assert dataset["total"] == 20
    assert dataset["end_time"] is not None

    # Both operators (Input, fused ReadRange->MapBatches) must be finished
    # with all 20 blocks accounted for.
    operators = dataset["operators"]
    assert len(operators) == 2
    assert "Input_0" in operators
    assert "ReadRange->MapBatches(<lambda>)_1" in operators
    for value in operators.values():
        assert value["name"] in ["Input", "ReadRange->MapBatches(<lambda>)"]
        assert value["progress"] == 20
        assert value["total"] == 20
        assert value["state"] == "FINISHED"


def test_stats_actor_datasets_eviction(ray_start_cluster):
    """
    Tests that finished datasets are evicted from the _StatsActor when
    the number of datasets exceeds the configured `max_stats` limit.
    """
    # Set a low max_stats limit to easily trigger eviction.
    max_stats = 2
    # Create a dedicated _StatsActor for this test to avoid interfering
    # with the global actor.
    stats_actor = _StatsActor.remote(max_stats=max_stats)

    # Patch the function that retrieves the stats actor to return our
    # test-specific actor instance.
    with patch(
        "ray.data._internal.stats.get_or_create_stats_actor",
        return_value=stats_actor,
    ):

        def check_ds_finished(ds_name: str) -> bool:
            """Helper to check if a dataset is marked as FINISHED in the actor."""
            datasets = ray.get(stats_actor.get_datasets.remote())
            ds_tag = next((tag for tag in datasets if tag.startswith(ds_name)), None)
            if not ds_tag:
                return False
            return datasets[ds_tag]["state"] == DatasetState.FINISHED.name

        # --- DS1 ---
        # Create and materialize the first dataset.
        ds1 = ray.data.range(1, override_num_blocks=1)
        ds1.set_name("ds1")
        ds1.materialize()
        # Wait until the actor has been updated with the FINISHED state.
        wait_for_condition(lambda: check_ds_finished("ds1"))

        # --- DS2 ---
        # Create and materialize the second dataset.
        # This brings the total number of datasets to the `max_stats` limit.
        ds2 = ray.data.range(1, override_num_blocks=1)
        ds2.set_name("ds2")
        ds2.materialize()
        wait_for_condition(lambda: check_ds_finished("ds2"))

        # --- Verify state before eviction ---
        # At this point, both ds1 and ds2 should be in the actor.
        datasets = ray.get(stats_actor.get_datasets.remote())
        names_in_actor = {k.split("_")[0] for k in datasets.keys()}
        assert names_in_actor == {"ds1", "ds2"}

        # --- DS3 ---
        # Create and materialize the third dataset. This should trigger the
        # eviction of the oldest finished dataset (ds1).
        ds3 = ray.data.range(1, override_num_blocks=1)
        ds3.set_name("ds3")
        ds3.materialize()

        def check_eviction() -> bool:
            """
            Helper to check that the actor state reflects the eviction.
            The actor should now contain ds2 and ds3, but not ds1.
            """
            datasets = ray.get(stats_actor.get_datasets.remote())
            # The eviction happens asynchronously, so we might briefly see 3 datasets.
            # We wait until the count is back to 2.
            if len(datasets) == max_stats + 1:
                return False
            names = {k.split("_")[0] for k in datasets.keys()}
            assert names == {"ds2", "ds3"}
            return True

        # Wait until the eviction has occurred and the actor state is correct.
        wait_for_condition(check_eviction)


# Setting interval=10000 (super high number) value so they are only called
# once (on cold start), and on shutdown.
@patch.object(StreamingExecutor, "UPDATE_METRICS_INTERVAL_S", new=10000)
@patch.object(BatchIterator, "UPDATE_METRICS_INTERVAL_S", new=10000)
@patch("ray.data._internal.stats.get_or_create_stats_actor")
def test_stats_manager(mock_get_or_create, shutdown_only):
    """StatsManager issues exactly one cold-start and one shutdown metrics
    update per dataset, even with many datasets iterated concurrently."""

    # Configure what get_or_create_stats_actor() returns
    mock_actor = MagicMock()
    mock_get_or_create.return_value = mock_actor

    ray.init()
    num_threads = 10

    datasets = [None] * num_threads

    def update_stats_manager(i):
        # Each thread builds and fully iterates its own small dataset.
        datasets[i] = ray.data.range(10).map_batches(lambda x: x)
        for _ in datasets[i].iter_batches(batch_size=1):
            pass

    threads = [
        threading.Thread(target=update_stats_manager, args=(i,), daemon=True)
        for i in range(num_threads)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    # Count calls to register_dataset.remote()
    register_dataset_calls = mock_actor.register_dataset.remote.call_count
    # Count calls to update_iteration_metrics.remote()
    iteration_calls = mock_actor.update_iteration_metrics.remote.call_count
    # Count calls to update_execution_metrics.remote()
    execution_calls = mock_actor.update_execution_metrics.remote.call_count

    # Each thread handles 1 dataset.
    assert register_dataset_calls == num_threads

    # Since interval is set to high value, the number of execution
    # calls will update on the first update (cold start), and on shutdown,
    # which is 2 for each thread.
    assert execution_calls == 2 * num_threads
    assert iteration_calls == 2 * num_threads


def test_stats_manager_stale_actor_handle(ray_start_cluster):
    """
    This test asserts that StatsManager is able to handle appropriately
    cases of StatsActor being killed upon driver disconnecting from running
    Ray cluster

    See https://github.com/ray-project/ray/issues/54841 for more details
    """

    class Identity:
        def __call__(self, batch):
            return batch

    def _run_driver_session():
        # Connect, run a simple actor-pool pipeline, then disconnect. The
        # second session exercises StatsManager's handling of the stale
        # StatsActor handle left over from the first driver session.
        ray.init(ignore_reinit_error=True)
        ray.data.range(1000).map_batches(
            Identity,
            concurrency=(1, 4),
            num_cpus=1,
        ).take_all()
        ray.shutdown()

    # First driver run
    _run_driver_session()

    # Second driver run
    _run_driver_session()


def test_runtime_metrics_histogram_observe():
    """Test that RuntimeMetricsHistogram correctly places values in buckets."""
    # Boundaries [1.0, 5.0, 10.0] define four buckets:
    # (<1.0), [1.0, 5.0), [5.0, 10.0), (>=10.0).
    hist = RuntimeMetricsHistogram([1.0, 5.0, 10.0])

    # One observation per bucket, in order.
    for value in (0.5, 3.0, 7.0, 15.0):
        hist.observe(value)

    # A batched observation lands entirely in bucket 1 (1.0 <= 2.0 < 5.0).
    hist.observe(2.0, num_observations=3)

    # Bucket counts after all observations: [bucket0, bucket1, bucket2, bucket3].
    assert hist._bucket_counts == [1, 4, 1, 1]

    # The string form reports total sample count and the running average.
    assert f"{hist}" == "(samples: 7, avg: 5.00)"


def test_runtime_metrics_histogram_export_to():
    """Verify export_to replays observations onto a Ray Histogram metric and
    only re-emits the delta on subsequent exports for the same tag set."""
    from ray.util.metrics import Histogram

    # Boundaries [1.0, 3.0] define three buckets.
    hist = RuntimeMetricsHistogram([1.0, 3.0])
    for value in (0.5, 2.0, 5.0):  # one observation per bucket
        hist.observe(value)

    # Stand-in for a real Ray Histogram metric.
    metric = MagicMock(spec=Histogram)
    metric.last_applied_bucket_counts_for_tags = {}

    tags = {"node_id": "test_node"}
    hist.export_to(metric, tags)

    # One Histogram.observe call per recorded observation.
    assert metric.observe.call_count == 3

    # export_to emits a representative value per bucket (a midpoint-like
    # value); each emitted value must fall within the bucket it stands for:
    #   bucket 0: 0 to 1.0, bucket 1: 1.0 to 3.0, bucket 2: >= 3.0.
    first_values = [c[0][0] for c in metric.observe.call_args_list]
    assert any(0 <= v <= 1.0 for v in first_values)
    assert any(1.0 <= v <= 3.0 for v in first_values)
    assert any(v >= 3.0 for v in first_values)

    # Applied counts are remembered keyed by the JSON-serialized tags.
    tags_key = '{"node_id": "test_node"}'
    assert tags_key in metric.last_applied_bucket_counts_for_tags
    assert metric.last_applied_bucket_counts_for_tags[tags_key] == [1, 1, 1]

    # Two fresh observations should yield exactly two more observe calls.
    hist.observe(0.8)  # bucket 0
    hist.observe(1.2)  # bucket 1
    hist.export_to(metric, tags)

    assert metric.observe.call_count == 5

    # The newly emitted values must fall in the buckets just observed.
    delta_values = [c[0][0] for c in metric.observe.call_args_list[2:]]
    assert any(0 <= v <= 1.0 for v in delta_values)
    assert any(1.0 <= v <= 3.0 for v in delta_values)

    # Bookkeeping now reflects the cumulative applied counts.
    assert metric.last_applied_bucket_counts_for_tags[tags_key] == [2, 2, 1]


if __name__ == "__main__":
    # Allow running this test module directly; exit with pytest's status code.
    raise SystemExit(pytest.main(["-vv", __file__]))